{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d14d29ebd3592ecb",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Late Interaction Text Embedding Models\n",
    "\n",
    "As of version 0.3.0 FastEmbed supports Late Interaction Text Embedding Models and currently available with one of the most popular embedding model of the family - ColBERT.\n",
    "\n",
    "## What is a Late Interaction Text Embedding Model?\n",
    "\n",
    "Late Interaction Text Embedding Model is a kind of information retrieval model which performs query and documents interactions at the scoring stage.\n",
    "In order to better understand it, we can compare it to the models without interaction.  \n",
    "For instance, if you take a sentence-transformer model, compute embeddings for your documents, compute embeddings for your queries, and just compare them by cosine similarity, then you're retrieving points without interaction.\n",
    "\n",
    "It is a pretty much easy and straightforward approach, however we might be sacrificing some precision due to its simplicity. It is caused by several facts: \n",
    "- there is no interaction between queries and documents at the early stage (embedding generation) nor at the late stage (during scoring). \n",
    "- we are trying to encapsulate all the document information in only one pooled embedding, and obviously, some information might be lost.\n",
    "\n",
    "Late Interaction Text Embedding models are trying to address it by computing embeddings for each token in queries and documents, and then finding the most similar ones via model specific operation, e.g. ColBERT (Contextual Late Interaction over BERT) uses MaxSim operation.\n",
    "With this approach we can have not only a better representation of the documents, but also make queries and documents more aware one of another.\n",
    "\n",
    "For more information on ColBERT and MaxSim operation, you can check out [this blogpost](https://jina.ai/news/what-is-colbert-and-late-interaction-and-why-they-matter-in-search/) by Jina AI.\n",
    "\n",
    "## ColBERT in FastEmbed\n",
    "\n",
    "FastEmbed provides a simple way to use ColBERT model, similar to the ones it has with `TextEmbedding`.\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7f1053b17c810be5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-03T17:20:26.927643Z",
     "start_time": "2024-06-03T17:20:25.128994Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": "[{'model': 'colbert-ir/colbertv2.0',\n  'dim': 128,\n  'description': 'Late interaction model',\n  'size_in_GB': 0.44,\n  'sources': {'hf': 'colbert-ir/colbertv2.0'},\n  'model_file': 'model.onnx'}]"
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fastembed import LateInteractionTextEmbedding\n",
    "\n",
    "LateInteractionTextEmbedding.list_supported_models()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c2c15893df422631",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-03T17:23:35.764183Z",
     "start_time": "2024-06-03T17:23:21.630277Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]\n",
      "config.json: 100%|██████████| 743/743 [00:00<00:00, 4.56MB/s]\n",
      "\n",
      "tokenizer_config.json: 100%|██████████| 405/405 [00:00<00:00, 3.34MB/s]\n",
      "Fetching 5 files:  20%|██        | 1/5 [00:00<00:01,  3.64it/s]\n",
      "tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]\u001b[A\n",
      "\n",
      "special_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 727kB/s]\n",
      "\n",
      "tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 1.48MB/s]\u001b[A\n",
      "\n",
      "model.onnx:   0%|          | 0.00/436M [00:00<?, ?B/s]\u001b[A\n",
      "model.onnx:   2%|▏         | 10.5M/436M [00:00<00:34, 12.2MB/s]\u001b[A\n",
      "model.onnx:   5%|▍         | 21.0M/436M [00:01<00:20, 20.3MB/s]\u001b[A\n",
      "model.onnx:   7%|▋         | 31.5M/436M [00:01<00:15, 25.7MB/s]\u001b[A\n",
      "model.onnx:  10%|▉         | 41.9M/436M [00:01<00:13, 29.4MB/s]\u001b[A\n",
      "model.onnx:  12%|█▏        | 52.4M/436M [00:01<00:12, 31.9MB/s]\u001b[A\n",
      "model.onnx:  14%|█▍        | 62.9M/436M [00:02<00:11, 33.7MB/s]\u001b[A\n",
      "model.onnx:  17%|█▋        | 73.4M/436M [00:02<00:10, 34.9MB/s]\u001b[A\n",
      "model.onnx:  19%|█▉        | 83.9M/436M [00:02<00:09, 35.5MB/s]\u001b[A\n",
      "model.onnx:  22%|██▏       | 94.4M/436M [00:03<00:09, 36.1MB/s]\u001b[A\n",
      "model.onnx:  24%|██▍       | 105M/436M [00:03<00:09, 36.6MB/s] \u001b[A\n",
      "model.onnx:  26%|██▋       | 115M/436M [00:03<00:08, 36.9MB/s]\u001b[A\n",
      "model.onnx:  29%|██▉       | 126M/436M [00:03<00:08, 37.1MB/s]\u001b[A\n",
      "model.onnx:  31%|███▏      | 136M/436M [00:04<00:08, 37.3MB/s]\u001b[A\n",
      "model.onnx:  34%|███▎      | 147M/436M [00:04<00:07, 37.4MB/s]\u001b[A\n",
      "model.onnx:  36%|███▌      | 157M/436M [00:04<00:07, 37.4MB/s]\u001b[A\n",
      "model.onnx:  38%|███▊      | 168M/436M [00:05<00:07, 37.5MB/s]\u001b[A\n",
      "model.onnx:  41%|████      | 178M/436M [00:05<00:06, 37.6MB/s]\u001b[A\n",
      "model.onnx:  43%|████▎     | 189M/436M [00:05<00:06, 37.6MB/s]\u001b[A\n",
      "model.onnx:  46%|████▌     | 199M/436M [00:05<00:06, 37.6MB/s]\u001b[A\n",
      "model.onnx:  48%|████▊     | 210M/436M [00:06<00:06, 37.5MB/s]\u001b[A\n",
      "model.onnx:  50%|█████     | 220M/436M [00:06<00:05, 37.5MB/s]\u001b[A\n",
      "model.onnx:  53%|█████▎    | 231M/436M [00:06<00:05, 37.6MB/s]\u001b[A\n",
      "model.onnx:  55%|█████▌    | 241M/436M [00:06<00:05, 37.6MB/s]\u001b[A\n",
      "model.onnx:  58%|█████▊    | 252M/436M [00:07<00:04, 37.6MB/s]\u001b[A\n",
      "model.onnx:  60%|██████    | 262M/436M [00:07<00:04, 37.7MB/s]\u001b[A\n",
      "model.onnx:  63%|██████▎   | 273M/436M [00:07<00:04, 37.7MB/s]\u001b[A\n",
      "model.onnx:  65%|██████▍   | 283M/436M [00:08<00:04, 36.0MB/s]\u001b[A\n",
      "model.onnx:  67%|██████▋   | 294M/436M [00:08<00:03, 36.4MB/s]\u001b[A\n",
      "model.onnx:  70%|██████▉   | 304M/436M [00:08<00:03, 36.8MB/s]\u001b[A\n",
      "model.onnx:  72%|███████▏  | 315M/436M [00:08<00:03, 37.0MB/s]\u001b[A\n",
      "model.onnx:  75%|███████▍  | 325M/436M [00:09<00:02, 37.3MB/s]\u001b[A\n",
      "model.onnx:  77%|███████▋  | 336M/436M [00:09<00:03, 30.8MB/s]\u001b[A\n",
      "model.onnx:  79%|███████▉  | 346M/436M [00:10<00:02, 32.6MB/s]\u001b[A\n",
      "model.onnx:  82%|████████▏ | 357M/436M [00:10<00:02, 33.9MB/s]\u001b[A\n",
      "model.onnx:  84%|████████▍ | 367M/436M [00:10<00:01, 34.8MB/s]\u001b[A\n",
      "model.onnx:  87%|████████▋ | 377M/436M [00:10<00:01, 35.7MB/s]\u001b[A\n",
      "model.onnx:  89%|████████▉ | 388M/436M [00:11<00:01, 36.2MB/s]\u001b[A\n",
      "model.onnx:  91%|█████████▏| 398M/436M [00:11<00:01, 36.6MB/s]\u001b[A\n",
      "model.onnx:  94%|█████████▍| 409M/436M [00:11<00:00, 36.9MB/s]\u001b[A\n",
      "model.onnx:  96%|█████████▌| 419M/436M [00:11<00:00, 37.1MB/s]\u001b[A\n",
      "model.onnx:  99%|█████████▊| 430M/436M [00:12<00:00, 37.3MB/s]\u001b[A\n",
      "model.onnx: 100%|██████████| 436M/436M [00:12<00:00, 35.1MB/s]\u001b[A\n",
      "Fetching 5 files: 100%|██████████| 5/5 [00:13<00:00,  2.68s/it]\n"
     ]
    }
   ],
   "source": [
    "embedding_model = LateInteractionTextEmbedding(\"colbert-ir/colbertv2.0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e560b5fa7d63bea3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-03T17:39:33.400876Z",
     "start_time": "2024-06-03T17:39:33.397431Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "documents = [\n",
    "    \"ColBERT is a late interaction text embedding model, however, there are also other models such as TwinBERT.\",\n",
    "    \"On the contrary to the late interaction models, the early interaction models contains interaction steps at embedding generation process\",\n",
    "]\n",
    "queries = [\n",
    "    \"Are there any other late interaction text embedding models except ColBERT?\",\n",
    "    \"What is the difference between late interaction and early interaction text embedding models?\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "347ad924a3449743",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "*NOTE*: ColBERT computes query and documents embeddings differently, make sure to use the corresponding methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "496fbf51e4eaaae",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-03T17:39:34.379885Z",
     "start_time": "2024-06-03T17:39:34.316257Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "document_embeddings = list(\n",
    "    embedding_model.embed(documents)\n",
    ")  # embed and qury_embed return generators,\n",
    "# which we need to evaluate by writing them to a list\n",
    "query_embeddings = list(embedding_model.query_embed(queries))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "50595bb0498f0c7c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-03T17:39:34.793528Z",
     "start_time": "2024-06-03T17:39:34.788545Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": "((26, 128), (32, 128))"
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "document_embeddings[0].shape, query_embeddings[0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13e43f2c24a7d5fc",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Don't worry about query embeddings having the bigger shape in this case. \n",
    "ColBERT authors recommend to pad queries with [MASK] tokens to 32 tokens.\n",
    "They also recommends to truncate queries to 32 tokens, however we don't do that in FastEmbed, so you can put some straight into the queries."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb1a4011effd3699",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## MaxSim operator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9ea4cf82521f2de",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Qdrant will support ColBERT as of the next version (v1.10), however, at the moment, you can compute embedding similarities manually.   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f84392f63d2c6076",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-03T17:39:36.431622Z",
     "start_time": "2024-06-03T17:39:36.427363Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_relevance_scores(query_embedding: np.array, document_embeddings: np.array, k: int):\n",
    "    \"\"\"\n",
    "    Compute relevance scores for top-k documents given a query.\n",
    "\n",
    "    :param query_embedding: Numpy array representing the query embedding, shape: [num_query_terms, embedding_dim]\n",
    "    :param document_embeddings: Numpy array representing embeddings for documents, shape: [num_documents, max_doc_length, embedding_dim]\n",
    "    :param k: Number of top documents to return\n",
    "    :return: Indices of the top-k documents based on their relevance scores\n",
    "    \"\"\"\n",
    "    # Compute batch dot-product of query_embedding and document_embeddings\n",
    "    # Resulting shape: [num_documents, num_query_terms, max_doc_length]\n",
    "    scores = np.matmul(query_embedding, document_embeddings.transpose(0, 2, 1))\n",
    "\n",
    "    # Apply max-pooling across document terms (axis=2) to find the max similarity per query term\n",
    "    # Shape after max-pool: [num_documents, num_query_terms]\n",
    "    max_scores_per_query_term = np.max(scores, axis=2)\n",
    "\n",
    "    # Sum the scores across query terms to get the total score for each document\n",
    "    # Shape after sum: [num_documents]\n",
    "    total_scores = np.sum(max_scores_per_query_term, axis=1)\n",
    "\n",
    "    # Sort the documents based on their total scores and get the indices of the top-k documents\n",
    "    sorted_indices = np.argsort(total_scores)[::-1][:k]\n",
    "\n",
    "    return sorted_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c61d07bed7b60e35",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-03T17:39:37.053383Z",
     "start_time": "2024-06-03T17:39:37.050926Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sorted document indices: [0 1]\n"
     ]
    }
   ],
   "source": [
    "sorted_indices = compute_relevance_scores(\n",
    "    np.array(query_embeddings[0]), np.array(document_embeddings), k=3\n",
    ")\n",
    "print(\"Sorted document indices:\", sorted_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b24df2569970d9e8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-03T17:40:52.276846Z",
     "start_time": "2024-06-03T17:40:52.273789Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: Are there any other late interaction text embedding models except ColBERT?\n",
      "Document: ColBERT is a late interaction text embedding model, however, there are also other models such as TwinBERT.\n",
      "Document: On the contrary to the late interaction models, the early interaction models contains interaction steps at embedding generation process\n"
     ]
    }
   ],
   "source": [
    "print(f\"Query: {queries[0]}\")\n",
    "for index in sorted_indices:\n",
    "    print(f\"Document: {documents[index]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6de537c37aff3927",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Use-case recommendation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37e3525d3259cd2b",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Despite ColBERT allows to compute embeddings independently and spare some workload offline, it still computes more resources than no interaction models. Due to this, it might be more reasonable to use ColBERT not as a first-stage retriever, but as a re-ranker.\n",
    "\n",
    "The first-stage retriever would then be a no-interaction model, which e.g. retrieves first 100 or 500 examples, and leave the final ranking to the ColBERT model."
   ]
  },
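  {
   "cell_type": "markdown",
   "id": "a7c3e1f09b2d4c58",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The two-stage setup described above could look roughly like the sketch below. It uses plain NumPy and assumes you already have one pooled, L2-normalized vector per document from a no-interaction model, plus per-token ColBERT embeddings. `maxsim_score` and `retrieve_then_rerank` are hypothetical helper names, not part of FastEmbed.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def maxsim_score(query_tokens: np.ndarray, doc_tokens: np.ndarray) -> float:\n",
    "    # Similarity of every query token to every document token,\n",
    "    # shape: [num_query_tokens, num_doc_tokens]\n",
    "    similarities = query_tokens @ doc_tokens.T\n",
    "    # Best document token per query token, summed over query tokens\n",
    "    return float(similarities.max(axis=1).sum())\n",
    "\n",
    "\n",
    "def retrieve_then_rerank(\n",
    "    query_dense: np.ndarray,  # [dense_dim], pooled query vector (assumed normalized)\n",
    "    doc_dense: np.ndarray,  # [num_documents, dense_dim], pooled document vectors\n",
    "    query_tokens: np.ndarray,  # [num_query_tokens, colbert_dim]\n",
    "    doc_tokens: list[np.ndarray],  # one [num_doc_tokens, colbert_dim] array per document\n",
    "    first_stage_k: int,\n",
    "    final_k: int,\n",
    ") -> np.ndarray:\n",
    "    # Stage 1: cheap no-interaction retrieval with a single dot product per document\n",
    "    dense_scores = doc_dense @ query_dense\n",
    "    candidates = np.argsort(dense_scores)[::-1][:first_stage_k]\n",
    "\n",
    "    # Stage 2: re-rank only the candidates with the more expensive MaxSim score\n",
    "    rerank_scores = np.array([maxsim_score(query_tokens, doc_tokens[i]) for i in candidates])\n",
    "    order = np.argsort(rerank_scores)[::-1][:final_k]\n",
    "    return candidates[order]\n",
    "```\n",
    "\n",
    "Because the second stage scores one candidate at a time, the per-document token arrays in this sketch do not need to share the same length."
   ]
  },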
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfa922793454b4ad",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
